
import numpy as np
import time
import tensorflow as tf


from tensorflow.python.client import device_lib

flags = tf.flags
logging = tf.logging


# Command-line flags. Names and defaults are unchanged; the help strings were
# garbled/truncated in the original and have been restored to readable English.
flags.DEFINE_string(
        'model', 'small',
        "A type of model. Possible options are: small, medium, large."
        )
flags.DEFINE_string(
        'data_path', None, 'Where the training/test data is stored'
        )

flags.DEFINE_bool('use_fp16', False,
                  "Train using 16-bit floats instead of 32 bit floats")

# NOTE(review): the original help text was cut off after "GRappler"; restored
# from the upstream TF PTB tutorial wording — confirm multi-GPU behavior.
flags.DEFINE_integer('num_gpus', 1,
                     "If larger than 1, Grappler AutoParallel optimizer "
                     "will create multiple training replicas with each GPU "
                     "running one replica.")

# NOTE(review): the upstream tutorial names this flag 'rnn_mode'; kept as
# 'rnn_model' here to preserve the command-line interface.
flags.DEFINE_string('rnn_model', None,
                    "The low level implementation of lstm cell: one of "
                    "'basic', 'block', or 'cudnn'.")

FLAGS = flags.FLAGS

# Supported RNN implementations (compared against config.rnn_mode below).
BASIC = 'basic'
CUDNN = 'cudnn'
BLOCK = 'block'

def data_type():
    """Return the floating-point dtype selected by the --use_fp16 flag."""
    if FLAGS.use_fp16:
        return tf.float16
    return tf.float32

class PBTInput(object):
    """The input data.

    Wraps ``reader.ptb_producer`` to expose batched input/target tensors
    along with the batch geometry taken from ``config``.
    """
    # NOTE(review): class name looks like a typo of "PTBInput"; kept as-is
    # because external callers may reference this name.

    def __init__(self, config, data, name=None):
        # Mirror the batch geometry from config onto the instance so the
        # model can read it without holding a reference to config.
        self.batch_size = batch_size = config.batch_size
        self.num_steps = num_steps= config.num_steps
        # NOTE(review): `reader` is never imported in this file — presumably
        # the TF tutorial's reader module (`import reader`); confirm it is in
        # scope, otherwise this raises NameError at construction time.
        self.input_data, self.targets = reader.ptb_producer(
                data, batch_size, num_steps, name=name
                )

class PTBModel(object):
    """The PTB language model.

    Builds: embedding lookup -> (dropout) -> RNN -> softmax projection ->
    sequence loss, plus (when training) gradient-clipped SGD with an
    externally assignable learning rate.
    """

    def __init__(self, is_training, config, input_):
        """Build the model graph.

        Args:
            is_training: bool; enables dropout and the training ops.
                (Fixed: the parameter was misspelled ``is_traning`` while the
                body read ``is_training``, which raised NameError.)
            config: hyperparameter object; this code reads hidden_size,
                vocab_size, keep_prob, max_grad_norm, num_layers and rnn_mode.
            input_: input pipeline object providing batch_size, num_steps,
                input_data and targets (e.g. PBTInput above).
        """
        self.is_training = is_training
        self._input = input_
        self._rnn_params = None
        self._cell = None
        self.batch_size = input_.batch_size
        self.num_steps = input_.num_steps
        size = config.hidden_size  # fixed typo: was config.hedden_size
        vocab_size = config.vocab_size

        # Embedding lookup is pinned to the CPU (the usual TF-1.x tutorial
        # placement for large embedding tables).
        with tf.device("/cpu:0"):
            embedding = tf.get_variable(
                    "embedding", [vocab_size, size], dtype=data_type())
            inputs = tf.nn.embedding_lookup(embedding, input_.input_data)

        if is_training and config.keep_prob < 1:  # fixed typo: was cnfig
            inputs = tf.nn.dropout(inputs, config.keep_prob)

        output, state = self._build_rnn_graph(inputs, config, is_training)

        softmax_w = tf.get_variable(
                "softmax_w", [size, vocab_size], dtype=data_type())
        softmax_b = tf.get_variable('softmax_b', [vocab_size], dtype=data_type())
        logits = tf.nn.xw_plus_b(output, softmax_w, softmax_b)
        # Reshape logits to 3-D [batch, time, vocab] as required by
        # sequence_loss. Fixed typo: was vocab_sizie.
        logits = tf.reshape(logits, [self.batch_size, self.num_steps, vocab_size])

        # Per-timestep cross-entropy, averaged over the batch but summed over
        # time (average_across_timesteps=False) — fixed typos inpupt_ and
        # date_type.
        loss = tf.contrib.seq2seq.sequence_loss(
                logits,
                input_.targets,
                tf.ones([self.batch_size, self.num_steps], dtype=data_type()),
                average_across_timesteps=False,
                average_across_batch=True)

        self._cost = tf.reduce_sum(loss)
        self._final_state = state

        # Evaluation-only models stop here; no optimizer ops are built.
        if not is_training:
            return

        self._lr = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()  # fixed typo: trainable_vaariables
        grads, _ = tf.clip_by_global_norm(
                tf.gradients(self._cost, tvars), config.max_grad_norm)

        optimizer = tf.train.GradientDescentOptimizer(self._lr)
        # Fixed typo: was tf.get_or_crate_global_stepp (no such API).
        self._train_op = optimizer.apply_gradients(
                zip(grads, tvars),
                global_step=tf.train.get_or_create_global_step())

        # Placeholder + assign op so the driver loop can anneal the learning
        # rate between epochs. Fixed op-name typo: new_learing_rate.
        self._new_lr = tf.placeholder(tf.float32, shape=[],
                                      name='new_learning_rate')
        self._lr_update = tf.assign(self._lr, self._new_lr)

    # The two builders below were nested inside __init__ in the original
    # source (after the construction code that calls self._build_rnn_graph),
    # so the attribute lookup failed at graph-build time. They are now proper
    # class-level methods.

    def _build_rnn_graph(self, inputs, config, is_training):
        """Dispatch to the CuDNN or cell-based RNN builder per config.rnn_mode."""
        if config.rnn_mode == CUDNN:
            return self._build_rnn_graph_cudnn(inputs, config, is_training)
        # NOTE(review): _build_rnn_graph_lstm is not defined in the visible
        # portion of this file — confirm it exists elsewhere in the class.
        return self._build_rnn_graph_lstm(inputs, config, is_training)

    def _build_rnn_graph_cudnn(self, inputs, config, is_training):
        """Build the RNN with the fused CudnnLSTM kernel.

        NOTE(review): this method appears truncated relative to the upstream
        TF PTB tutorial — it constructs the cell but never runs it or returns
        (output, state) as _build_rnn_graph's callers expect. TODO: confirm
        the remainder exists in a later chunk of this file.
        """
        # CudnnLSTM consumes time-major input: [time, batch, features].
        inputs = tf.transpose(inputs, [1, 0, 2])
        # Fixed typo: was tf.contrib.cudnn_rnn_CudnnLSTM (attribute, not
        # module access).
        self._cell = tf.contrib.cudnn_rnn.CudnnLSTM(
                num_layers=config.num_layers,
                num_units=config.hidden_size,
                input_size=config.hidden_size,
                dropout=1 - config.keep_prob if is_training else 0)



